{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Sentiment Analysis with MXNet and Gluon\n",
    "\n",
    "This tutorial will show how to train and test a Sentiment Analysis (Text Classification) model on SageMaker using MXNet and the Gluon API.\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import boto3\n",
    "import sagemaker\n",
    "from sagemaker.mxnet import MXNet\n",
    "from sagemaker import get_execution_role\n",
    "\n",
    "sagemaker_session = sagemaker.Session()\n",
    "\n",
    "role = get_execution_role()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Download training and test data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this notebook, we will train a sentiment analysis model on the [SST-2 dataset (Stanford Sentiment Treebank 2)](https://nlp.stanford.edu/sentiment/index.html). The dataset consists of movie reviews with one sentence per review. The task is to classify each review as positive or negative.  \n",
    "We will download a preprocessed version of this dataset from the links below. Each line in the dataset consists of space-separated tokens, the first token being the label: 1 for positive and 0 for negative."
   ]
  },
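  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The line format described above can be parsed with a few lines of Python. The following is a minimal sketch, not the parsing code used by `sentiment.py`; the function name `parse_line` and the sample line are made up for illustration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def parse_line(line):\n",
    "    # Split on whitespace: the first token is the label, the rest are words.\n",
    "    tokens = line.strip().split()\n",
    "    label = int(tokens[0])\n",
    "    words = tokens[1:]\n",
    "    return label, words\n",
    "\n",
    "# Hypothetical sample line in the same format as the data files.\n",
    "sample_line = '1 a gripping and heartfelt film .'\n",
    "sample_label, sample_words = parse_line(sample_line)\n",
    "print(sample_words)"
   ]
  },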
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%bash\n",
    "mkdir -p data\n",
    "curl https://raw.githubusercontent.com/saurabh3949/Text-Classification-Datasets/master/stsa.binary.phrases.train > data/train\n",
    "curl https://raw.githubusercontent.com/saurabh3949/Text-Classification-Datasets/master/stsa.binary.test > data/test"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Uploading the data\n",
    "\n",
    "We use the `sagemaker.Session.upload_data` function to upload our datasets to an S3 location. The return value, `inputs`, identifies the location; we will use it later when we start the training job."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "inputs = sagemaker_session.upload_data(path='data', key_prefix='data/DEMO-sentiment')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Implement the training function\n",
    "\n",
    "We need to provide a training script that can run on the SageMaker platform. The training script is essentially the same as one you would write for local training, except that it must provide a `train` function. When SageMaker calls this function, it passes in arguments that describe the training environment. Check the script below to see how this works.\n",
    "\n",
    "The script here is a simplified implementation of [\"Bag of Tricks for Efficient Text Classification\"](https://arxiv.org/abs/1607.01759), as implemented by Facebook's [FastText](https://github.com/facebookresearch/fastText/) for text classification. The model maps each word to a vector and averages the vectors of all the words in a sentence to form a hidden representation of the sentence, which is fed into a softmax classification layer. Please refer to the paper for more details."
   ]
  },
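  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the averaging step concrete, the model described above can be sketched in plain Python. This is an illustrative sketch only, not the Gluon code in `sentiment.py`; the names `average_embedding`, `softmax`, and `classify` and all toy values are hypothetical."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "def average_embedding(token_ids, embedding_table):\n",
    "    # Mean of the embedding vectors of all tokens in the sentence.\n",
    "    dim = len(embedding_table[0])\n",
    "    total = [0.0] * dim\n",
    "    for idx in token_ids:\n",
    "        for j in range(dim):\n",
    "            total[j] += embedding_table[idx][j]\n",
    "    return [value / len(token_ids) for value in total]\n",
    "\n",
    "def softmax(logits):\n",
    "    # Subtract the maximum before exponentiating for numerical stability.\n",
    "    peak = max(logits)\n",
    "    exps = [math.exp(v - peak) for v in logits]\n",
    "    norm = sum(exps)\n",
    "    return [e / norm for e in exps]\n",
    "\n",
    "def classify(token_ids, embedding_table, weights, bias):\n",
    "    # Hidden representation = averaged word vectors; output = softmax(W h + b).\n",
    "    hidden = average_embedding(token_ids, embedding_table)\n",
    "    logits = [sum(w * h for w, h in zip(row, hidden)) + b\n",
    "              for row, b in zip(weights, bias)]\n",
    "    return softmax(logits)\n",
    "\n",
    "# Toy values for illustration: 3-word vocabulary, 2-dimensional embeddings, 2 classes.\n",
    "toy_embeddings = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]\n",
    "toy_weights = [[1.0, -1.0], [-1.0, 1.0]]\n",
    "toy_bias = [0.0, 0.0]\n",
    "toy_probs = classify([0, 2], toy_embeddings, toy_weights, toy_bias)\n",
    "print(toy_probs)"
   ]
  },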
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!cat 'sentiment.py'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Run the training script on SageMaker\n",
    "\n",
    "The `MXNet` class allows us to run our training function on SageMaker infrastructure. We need to configure it with our training script, an IAM role, the number of training instances, and the training instance type. In this case we will run our training job on a single `ml.c4.xlarge` instance."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "m = MXNet('sentiment.py',\n",
    "          role=role,\n",
    "          train_instance_count=1,\n",
    "          train_instance_type='ml.c4.xlarge',\n",
    "          framework_version='1.4.0',\n",
    "          py_version='py2',\n",
    "          distributions={'parameter_server': {'enabled': True}},\n",
    "          hyperparameters={'batch-size': 8,\n",
    "                           'epochs': 2,\n",
    "                           'learning-rate': 0.01,\n",
    "                           'embedding-size': 50, \n",
    "                           'log-interval': 1000})"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "After we've constructed our `MXNet` object, we can fit it using the data we uploaded to S3. SageMaker makes sure our data is available in the local filesystem, so our training script can simply read the data from disk.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "m.fit(inputs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The logs show that we get more than 80% accuracy on the test set using the above hyperparameters.\n",
    "\n",
    "After training, we use the `MXNet` object to build and deploy an `MXNetPredictor` object. This creates a SageMaker endpoint that we can use to perform inference.\n",
    "\n",
    "The endpoint accepts a JSON-encoded array of strings as input."
   ]
  },
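  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough illustration of that input format, the sketch below encodes a list of sentences as a JSON array with the standard `json` module. The variable names are hypothetical; in this notebook the predictor is expected to handle the encoding itself, so this cell is only meant to show what such a payload looks like."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "# Hypothetical sentences, tokenized with spaces like the training data.\n",
    "sample_sentences = ['this movie was extremely good .',\n",
    "                    'the plot was very boring .']\n",
    "\n",
    "# A JSON-encoded array of strings.\n",
    "payload = json.dumps(sample_sentences)\n",
    "print(payload)\n",
    "\n",
    "# Decoding the payload recovers the original list.\n",
    "decoded = json.loads(payload)"
   ]
  },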
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "predictor = m.deploy(initial_instance_count=1, instance_type='ml.c4.xlarge')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The predictor runs inference on our input data and returns the predicted sentiment (1 for positive and 0 for negative)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "data = [\"this movie was extremely good .\",\n",
    "        \"the plot was very boring .\",\n",
    "        \"this film is so slick , superficial and trend-hoppy .\",\n",
    "        \"i just could not watch it till the end .\",\n",
    "        \"the movie was so enthralling !\"]\n",
    "\n",
    "response = predictor.predict(data)\n",
    "print(response)"
   ]
  },
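  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the output easier to read, the predicted labels can be mapped to names. The sketch below assumes the response is a list of 0/1 values in the same order as the input sentences; the helper `to_label_names` is hypothetical, and the `example_response` values are made up for illustration and are not output from the model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "LABEL_NAMES = {0: 'negative', 1: 'positive'}\n",
    "\n",
    "def to_label_names(predictions):\n",
    "    # Round first so that float outputs such as 1.0 map to the integer key.\n",
    "    return [LABEL_NAMES[int(round(p))] for p in predictions]\n",
    "\n",
    "# Made-up values for illustration only, not real model output.\n",
    "example_response = [1.0, 0.0, 1.0]\n",
    "example_names = to_label_names(example_response)\n",
    "print(example_names)"
   ]
  },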
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Cleanup\n",
    "\n",
    "After you have finished with this example, remember to delete the prediction endpoint to release the instance(s) associated with it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "predictor.delete_endpoint()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Environment (conda_mxnet_p27)",
   "language": "python",
   "name": "conda_mxnet_p27"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.11"
  },
  "notice": "Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.  Licensed under the Apache License, Version 2.0 (the \"License\"). You may not use this file except in compliance with the License. A copy of the License is located at http://aws.amazon.com/apache2.0/ or in the \"license\" file accompanying this file. This file is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License."
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
